# 导入标准库
import json
import os
# 导入 PyTorch
import torch
# 导入日志
from base import logger
# 导入numpy
import numpy as np
# 导入 Transformers 库
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
# 导入train_test_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix


class QueryClassifier:
	"""Binary BERT classifier that assigns a query to one of the names in ``label_map``.

	The instance owns a tokenizer and a ``BertForSequenceClassification`` model.
	Both are loaded from ``model_path`` when that path exists on disk, otherwise
	from ``BASE_MODEL`` so that a first run (no fine-tuned checkpoint yet) works.
	"""

	# Checkpoint used for both tokenizer and model when ``model_path`` is missing.
	BASE_MODEL = 'bert-base-chinese'
	# Truncation length passed to every tokenizer call.
	MAX_LENGTH = 128

	def __init__(self, model_path='bert_query_classifier'):
		"""Load the tokenizer and the model and pick the compute device.

		Args:
			model_path: Directory the fine-tuned model is loaded from and saved to.
		"""
		self.model_path = model_path
		# Take the tokenizer from the same source load_model() uses. Loading it
		# from a not-yet-existing model_path would fail before the model fallback
		# to BASE_MODEL is ever reached.
		self.tokenizer = BertTokenizer.from_pretrained(self._pretrained_source())
		self.model = None
		self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
		logger.info(f"使用设备：{self.device}")
		# Category name -> class id used as the training target.
		self.label_map = {"通用知识": 0, "专业咨询": 1}
		self.load_model()

	def _pretrained_source(self):
		"""Return ``model_path`` if it exists on disk, else ``BASE_MODEL``."""
		return self.model_path if os.path.exists(self.model_path) else self.BASE_MODEL

	def _tokenize(self, texts):
		"""Tokenize a string or list of strings into padded, truncated PyTorch tensors."""
		return self.tokenizer(
			texts,
			truncation=True,
			padding=True,
			max_length=self.MAX_LENGTH,
			return_tensors='pt'
		)

	def load_model(self):
		"""Load the classification model onto ``self.device``.

		Uses the checkpoint at ``model_path`` when present; otherwise starts from
		``BASE_MODEL`` with a fresh classification head.
		"""
		has_checkpoint = os.path.exists(self.model_path)
		self.model = BertForSequenceClassification.from_pretrained(
			self.model_path if has_checkpoint else self.BASE_MODEL,
			num_labels=len(self.label_map)
		)
		self.model.to(self.device)
		if has_checkpoint:
			logger.info(f"加载模型：{self.model_path}")
		else:
			logger.info("初始化新 Bert 模型")

	def save_model(self):
		"""Write the model and the tokenizer to ``model_path``."""
		self.model.save_pretrained(self.model_path)
		self.tokenizer.save_pretrained(self.model_path)
		logger.info(f"模型保存至：{self.model_path}")

	def preprocess_data(self, texts, labels):
		"""Convert raw samples into model inputs.

		Args:
			texts: List of query strings.
			labels: Category names; each must be a key of ``label_map``.

		Returns:
			Tuple ``(encodings, label_ids)`` where ``label_ids`` is a list of ints.

		Raises:
			KeyError: If a label is not present in ``label_map``.
		"""
		label_ids = [self.label_map[label] for label in labels]
		return self._tokenize(texts), label_ids

	def create_dataset(self, encodings, labels):
		"""Wrap encodings and numeric labels in a ``torch.utils.data.Dataset``.

		Each item is a dict holding the ``index``-th slice of every encoding
		entry plus a ``labels`` tensor.
		"""
		class Dataset(torch.utils.data.Dataset):
			def __init__(self, encodings, labels):
				self.encodings = encodings
				self.labels = labels

			def __getitem__(self, index):
				item = {key: val[index] for key, val in self.encodings.items()}
				item['labels'] = torch.tensor(self.labels[index])
				return item

			def __len__(self):
				return len(self.labels)
		return Dataset(encodings, labels)

	def _load_samples(self, data_file):
		"""Read a JSON-lines file and return parallel lists ``(queries, labels)``."""
		if not os.path.exists(data_file):
			logger.error(f"数据集文件 {data_file} 不存在")
			raise FileNotFoundError(f"数据集文件 {data_file} 不存在")

		with open(data_file, 'r', encoding='utf-8') as f:
			# Skip blank lines (e.g. a trailing newline), which json.loads rejects.
			records = [json.loads(line) for line in f if line.strip()]

		if not records:
			raise ValueError(f"数据集文件 {data_file} 中没有有效样本")

		queries = [record['query'] for record in records]
		labels = [record['label'] for record in records]
		return queries, labels

	def train_model(self, data_file='training_dataset_hybrid_5000.json'):
		"""Fine-tune the model on a JSON-lines dataset, save it, then evaluate it.

		Args:
			data_file: Path to a file with one JSON object per line, each having
				a ``query`` and a ``label`` field.

		Raises:
			FileNotFoundError: If ``data_file`` does not exist.
			ValueError: If the file contains no samples.
		"""
		texts, labels = self._load_samples(data_file)

		# Hold out 20% for validation; fixed seed keeps the split reproducible.
		train_texts, val_texts, train_labels, val_labels = train_test_split(
			texts, labels, test_size=0.2, random_state=42
		)

		# After this point the *_labels variables hold numeric class ids.
		train_encodings, train_labels = self.preprocess_data(train_texts, train_labels)
		val_encodings, val_labels = self.preprocess_data(val_texts, val_labels)

		train_dataset = self.create_dataset(train_encodings, train_labels)
		val_dataset = self.create_dataset(val_encodings, val_labels)

		train_args = TrainingArguments(
			output_dir='./bert_results',  # checkpoints and trainer output
			num_train_epochs=3,  # total number of training epochs
			per_device_train_batch_size=8,  # training batch size per device
			per_device_eval_batch_size=8,  # evaluation batch size per device
			warmup_steps=500,  # learning-rate warm-up steps
			weight_decay=0.01,  # weight decay coefficient
			logging_dir='./bert_logs',  # log directory
			logging_steps=10,  # log every N steps
			eval_strategy='epoch',  # evaluate at the end of every epoch
			save_strategy='epoch',  # checkpoint at the end of every epoch
			load_best_model_at_end=True,  # restore the best checkpoint after training
			save_total_limit=1,  # cap on the number of kept checkpoints
			metric_for_best_model='eval_loss',  # metric used to pick the best checkpoint
			fp16=False  # mixed-precision training disabled
		)

		trainer = Trainer(
			model=self.model,
			args=train_args,
			train_dataset=train_dataset,
			eval_dataset=val_dataset,
			compute_metrics=self.compute_metrics
		)

		logger.info("开始训练 Bert 模型...")
		trainer.train()
		self.save_model()

		self.evaluate_model(val_texts, val_labels)

	def compute_metrics(self, eval_pred):
		"""Return ``{"accuracy": ...}`` for a ``(logits, labels)`` pair.

		The prediction for each row is the arg-max over the last axis of ``logits``.
		"""
		logits, labels = eval_pred
		predictions = np.argmax(logits, axis=-1)
		# Plain float rather than a NumPy scalar.
		return {"accuracy": float((predictions == labels).mean())}

	def evaluate_model(self, texts, labels):
		"""Log a classification report and a confusion matrix for the given samples.

		Args:
			texts: List of query strings.
			labels: Numeric class ids; category names from ``label_map`` are also
				accepted and converted.
		"""
		true_labels = [
			self.label_map[label] if isinstance(label, str) else int(label)
			for label in labels
		]
		dataset = self.create_dataset(self._tokenize(texts), true_labels)

		trainer = Trainer(model=self.model)
		predictions = trainer.predict(dataset)
		pred_labels = np.argmax(predictions.predictions, axis=-1)

		# Pass the full id list explicitly so that the report still works when a
		# class is absent from this sample, and so names always line up with ids.
		ordered = sorted(self.label_map.items(), key=lambda pair: pair[1])
		class_ids = [class_id for _, class_id in ordered]
		class_names = [name for name, _ in ordered]

		logger.info("分类报告：")
		logger.info(classification_report(
			true_labels,
			pred_labels,
			labels=class_ids,
			target_names=class_names
		))
		logger.info("混淆矩阵：")
		logger.info(confusion_matrix(true_labels, pred_labels, labels=class_ids))

	def predict_category(self, query: str) -> str:
		"""Classify a single query string and return its category name.

		Returns "通用知识" when no model is loaded or when the predicted id is
		not present in ``label_map``.
		"""
		if self.model is None:
			logger.error("模型未训练或加载")
			return "通用知识"

		encodings = self._tokenize(query)
		encodings = {key: value.to(self.device) for key, value in encodings.items()}

		# Inference must run in eval mode; a model left in training mode would
		# apply dropout and give non-deterministic predictions.
		self.model.eval()
		with torch.no_grad():
			output = self.model(**encodings)
			prediction = torch.argmax(output.logits, dim=-1).item()

		id_to_label = {class_id: name for name, class_id in self.label_map.items()}
		return id_to_label.get(prediction, "通用知识")


if __name__ == "__main__":
	# Script entry point: build the classifier, fine-tune it, then print the
	# predicted category for a handful of sample queries.
	clf = QueryClassifier(model_path="bert_query_classifier")

	# Fine-tune on the JSON-lines dataset located relative to the working directory.
	clf.train_model(data_file='../classify_data/model_generic_5000.json')

	# Sample queries used as a quick manual check of the trained model.
	sample_queries = (
		"AI学科的课程大纲是什么",
		"JAVA课程费用多少？",
		"5*9等于多少？",
		"AI培训有哪些老师？",
	)
	for text in sample_queries:
		label = clf.predict_category(text)
		print(f"查询: {text} -> 分类: {label}")


